### #######################################
### Does variable selection using glmnet
### Output is a formula
### #######################################

here::i_am("newR/04_varsel.R")

library(here)
library(parallel)
library(glmnet)
library(tidyverse)
library(doMC)
library(knitr)
library(broom)
### Fix the RNG seed so that any random steps later in the script are
### reproducible between runs (presumably this is aimed at the
### cross-validation fold assignment in cv.glmnet, which is called below
### without an explicit foldid).
set.seed(88516)

### Format a numeric vector as "mean (sd)".
### Missing values are dropped before computing either statistic, and both
### numbers are rounded to three significant figures.
pretty_summary <- function(x) {
    stats <- c(mean(x, na.rm = TRUE),
               sd(x, na.rm = TRUE))
    stats <- signif(stats, 3)
    paste0(stats[1], " (", stats[2], ")")
}

### Turn the non-zero coefficients of a glmnet fit into a model formula.
###
### Arguments:
###   obj    - a list of one-column coefficient matrices with row names
###            (as produced by coef() on the multi-response fit below);
###            only the first element is inspected
###   depvar - name to put on the left-hand side of the formula
###
### Returns a formula `depvar ~ term1 + term2 + ...` containing each selected
### term once. If nothing other than the intercept is selected, the
### intercept-only formula `depvar ~ 1` is returned.
glmnet_to_formula <- function(obj, depvar = "y") {
    coefs <- obj[[1]]
    selected <- rownames(coefs)[which(coefs != 0)]
    ## Remove the intercept by name rather than by position: dropping the
    ## first non-zero row would discard a real predictor whenever the
    ## intercept happens to be exactly zero.
    selected <- selected[selected != "(Intercept)"]
    ## Handle Scotland and Wales: the dummy columns for the Nation factor
    ## map back onto the single variable Nation
    selected <- sub("NationScotland", "Nation", selected)
    selected <- sub("NationWales", "Nation", selected)
    selected <- unique(selected)
    ## An empty right-hand side does not parse as a formula
    if (length(selected) == 0) {
        selected <- "1"
    }

    f <- paste0(depvar,
                " ~ ",
                paste0(selected, collapse = " + "))
    as.formula(f)
}

### Register a multicore backend using all but two of the detected cores.
### detectCores() can return NA, and on machines with two or fewer cores the
### subtraction gives a non-positive count; fall back to a single worker in
### those cases rather than passing an invalid value to registerDoMC().
ncores <- parallel::detectCores() - 2
if (is.na(ncores) || ncores < 1) {
    ncores <- 1
}
registerDoMC(cores = ncores)

### Project helper functions
source(here::here("newR", "00_helpers.R"))

### Modelling data prepared earlier in the pipeline
dat <- readRDS(here::here("working", "tidy_model_data.rds"))

### Outcome columns (the *_BE variables)
pvars <- c("Con_BE", "Lab_BE", "Lib_BE", "Nat_BE", "Oth2_BE")

### Replace every non-positive value in the outcome columns with a small
### positive constant
for (p in pvars) {
    share <- dat[, p]
    dat[, p] <- replace(share, share <= 0, 1 / 40000)
}

### Make sure everything sums to 1: divide each row by its total
row_totals <- rowSums(dat[, pvars])
dat[, pvars] <- dat[, pvars] / row_totals

### Response matrix for glmnet, one column per outcome
y <- as.matrix(cbind(dat[, pvars]))

### Create matrix with all pairwise interactions in it for glmnet.
### `(a + b + ...) ^ 2` expands to every main effect plus every two-way
### interaction between the listed variables.
x <- model.matrix(~ (Lab_GE_sc + Lib_GE_sc + Nat_GE_sc + Oth_GE_sc +
                     LabGovt_sc + LibGovt_sc +
                     LabInc_sc + LibInc_sc + RPI_sc + winter_sc +
                     Turn_at_GE_sc + Nation +
                     LabPollChg_sc + LibPollChg_sc + OthPollChg_sc +
                     ConCand_BE_sc + LabCand_BE_sc +
                     LibCand_BE_sc + NatCand_BE_sc + OthCand_BE_sc +
                     ConCand_GE_sc + LabCand_GE_sc +
                     LibCand_GE_sc + NatCand_GE_sc + OthCand_GE_sc +
                     IncCandidateLib_sc + IncCandidateOther_sc) ^ 2,
                  data = dat)


### Do the cross-validation (multi-response Gaussian, all columns of y
### modelled jointly).
### A doMC backend is registered near the top of this script, but cv.glmnet
### only uses it when asked to, so request parallel fold fitting whenever
### more than one core was registered. The folds are still drawn in the
### main process, so this only changes where each fold is fitted.
mfit <- cv.glmnet(x, y, family = "mgaussian",
                  parallel = isTRUE(ncores > 1))

### Work out the number of non-zero coefficients at different points
### Replicate the graphs that glmnet does, but in ggplot

mft <- tidy(mfit)
### Locate lambda.min / lambda.1se on the lambda path by nearest value
### rather than exact floating-point equality
vars_min <- mfit$nzero[which.min(abs(mfit$lambda - mfit$lambda.min))]
vars_1se <- mfit$nzero[which.min(abs(mfit$lambda - mfit$lambda.1se))]

### Cross-validated error (with its interval) against log(lambda); the
### solid vertical line marks lambda.min (A) and the dashed line
### lambda.1se (B), each annotated with the number of selected variables
p1 <- ggplot(mft,
             aes(x = log(lambda),
                 y = estimate,
                 ymin = conf.low,
                 ymax = conf.high)) +
    geom_vline(xintercept = log(mfit$lambda.min), linetype = 1) +
    geom_vline(xintercept = log(mfit$lambda.1se), linetype = 2) +
    annotate(geom = "text",
             x = log(mfit$lambda.min),
             y = .6,
             size = 3,
             label = paste0("(A) \n",vars_min, " \nvariables \nselected "),
             hjust = "right") + 
    annotate(geom = "text",
             x = log(mfit$lambda.1se),
             y = .6,
             size = 3,
             label = paste0(" (B)\n ", vars_1se, "\n variables\n selected"),
             hjust = "left") + 
    geom_pointrange(colour = "darkgrey", fatten = 0.5) +
    scale_x_continuous("log(λ)") +
    scale_y_log10("MSE") + 
    theme_bw()

grDevices::cairo_pdf(here::here("article/figure",
                                "cvpath.pdf"),
                     width = 8, height = 4)
### Print explicitly: a bare `p1` is only auto-printed at the interactive
### top level / under Rscript, so running this file through source() would
### otherwise leave the PDF empty.
print(p1)
dev.off()

### The plot shows an odd hump in between values.
### NOTE(review): an earlier comment here described picking a lambda on the
### right-hand side of that hump; the code below simply uses glmnet's
### lambda.min and lambda.1se.

### What are the values of lambda, in order
lambdas <- mfit$lambda

### Coefficients and the corresponding formula at lambda.min ...
cobj_min <- coef(mfit, s = "lambda.min")
f_min <- glmnet_to_formula(cobj_min)

### ... and at lambda.1se
cobj_1se <- coef(mfit, s = "lambda.1se")
f_1se <- glmnet_to_formula(cobj_1se)


### Store both formulas for later scripts
save(list = c("f_min", "f_1se"),
     file = here::here("working", "varsel_formulas.rdata"))

### Can we get the inclusion order?
### Do it first for min
nonzero_min <- ncol(model.matrix(f_min, data = dat)) - 1

### Coefficient path for the first response: one row per term, one column
### per lambda. For each term, find the first position on the path at which
### its coefficient is non-zero (NA if it never enters).
beta_path <- as.matrix(mfit$glmnet.fit$beta$Con_BE)
incl_ord <- apply(beta_path, 1, function(b) which(b != 0)[1])

### Position of lambda.min on the path
pos <- which.min(abs(lambdas - mfit$lambda.min))

### Keep only terms that have entered by that position, then convert the
### path positions into an order of entry (ties share the lowest rank)
incl_ord <- incl_ord[!is.na(incl_ord) & incl_ord <= pos]
incl_ord <- rank(incl_ord, ties.method = "min")

incl_df <- data.frame(term = names(incl_ord),
                      order = incl_ord)

### Reshape the inclusion order into a table with one row per main term and
### one column per term it can interact with, plus a "Main term" column
### holding the order in which the main effect itself entered.
incl_df <- incl_df %>%
    ## Strip the "_sc" suffix from the variable names
    mutate(term = gsub("_sc", "", term)) %>%
    ## "a:b" becomes main_term = a, interaction_term = b. Main effects have
    ## no second piece; fill = "right" pads them with NA without emitting a
    ## warning for every such row.
    separate(term, sep = ":",
             into = c("main_term", "interaction_term"),
             fill = "right") %>%
    mutate(interaction_term = replace_na(interaction_term, "Main term")) %>%
    ## Give both columns the same set of levels so that complete() below
    ## produces every main term x interaction term combination
    mutate(main_term = factor(main_term,
                              levels = unique(c(main_term, interaction_term))),
           interaction_term = factor(interaction_term,
                                     levels = unique(c(as.character(main_term),
                                                       interaction_term))))

incl_df <- incl_df %>%
    complete(main_term, interaction_term) %>%
    pivot_wider(id_cols = main_term,
                names_from = interaction_term,
                values_from = order)

### Drop the placeholder row. A logical filter is used instead of
### `-which(...)`, which would remove every row if nothing matched.
incl_df <- incl_df %>%
    filter(as.character(main_term) != "Main term")

### Arrange the rows
incl_df <- incl_df %>%
    arrange(`Main term`)

### Arrange the columns to reflect this
incl_df <- incl_df[, c("main_term", "Main term",
                       as.character(incl_df$main_term))]

### Replace variable names with readable labels for the table rows
out <- incl_df %>%
    mutate(main_term = dplyr::recode(main_term,
                                     "Lab_GE" = "Labour GE share",
                                     "Lib_GE" = "Liberal GE share",
                                     "Nat_GE" = "Nat. GE share",
                                     "Oth_GE" = "Other GE share",
                                     "LabGovt" = "Labour govt",
                                     "LabInc" = "Labour party inc.",
                                     "Turn_at_GE" = "Turnout at GE",
                                     "LabPollChg" = "Labour poll change",
                                     "LibPollChg" = "Lib poll change",
                                     "OthPollChg" = "Other poll change",
                                     "LibCand_BE" = "Liberal candidate",
                                     "NatCand_BE" = "Nat. candidate",
                                     "OthCand_BE" = "Other candidate(s)",
                                     "LibGovt" = "Lib Dems in govt",
                                     "RPI" = "Inflation",
                                     "NationScotland" = "Scottish seat",
                                     "ConCand_BE" = "Conservative candidate",
                                     "LabCand_BE" = "Labour candidate",
                                     "IncCandidateLib" = "Personal incumbent running as Lib",
                                     "ConCand_GE" = "Con. candidate in GE",
                                     "LibCand_GE"= "Lib. candidate in GE",
                                     "NationWales" = "Welsh seat",
                                     "IncCandidateOther" = "Personal incumbent running as other",
                                     "NatCand_GE"= "Nat. candidate in GE",
                                     "LabCand_GE" = "Labour candidate in GE"))

### Number the rows, and label the interaction columns (everything after the
### first two columns) with the matching numbers. The counts are taken from
### the table itself rather than hard-coded, so this keeps working when a
### different number of main terms is selected.
out$main_term <- paste0("(", seq_len(nrow(out)), ") ", out$main_term)

n_term_cols <- ncol(out) - 2
colnames(out)[seq_len(n_term_cols) + 2] <- paste0("(", seq_len(n_term_cols), ")")
colnames(out)[1] <- "Variable"
### Show missing cells as blanks in the kable output
opts <- options(knitr.kable.NA = "")

### Write the table out as markdown
capture.output(kable(out,
                     caption = "Order of inclusion"),
               file = here::here("article/inserts",
                                 "order_of_inclusion_min.md"))

### Now do it for 1se
### (nonzero_1se is kept for reference; it is not used for the cut-off below)
nonzero_1se <- ncol(model.matrix(f_1se, data = dat)) - 1
### For each term, the first position on the lambda path at which its
### coefficient is non-zero (NA if it never enters)
incl_ord <- as.matrix(mfit$glmnet.fit$beta$Con_BE)
incl_ord <- apply(incl_ord, 1, function(x) which(x != 0)[1])
### NA out ones that only enter after lambda.1se.
### incl_ord holds positions on the lambda path, so it has to be compared
### with the path position of lambda.1se (as is done for lambda.min above),
### not with the number of columns in the 1se model matrix.
pos_1se <- which.min(abs(mfit$lambda - mfit$lambda.1se))
incl_ord[which(incl_ord > pos_1se)] <- NA
incl_ord <- na.omit(incl_ord)
incl_ord <- rank(incl_ord, ties.method = "min")
incl_df <- data.frame(term = names(incl_ord),
                      order = incl_ord)

### One row per main term, one column per interacting term, plus a
### "Main term" column for the main effect's own order of entry
incl_df <- incl_df %>%
    mutate(term = gsub("_sc", "", term)) %>%
    ## Main effects have no second piece; fill = "right" pads them with NA
    ## without a warning for every such row
    separate(term, sep = ":",
             into = c("main_term", "interaction_term"),
             fill = "right") %>%
    mutate(interaction_term = replace_na(interaction_term, "Main term")) %>%
    pivot_wider(id_cols = main_term, names_from = interaction_term, values_from = order)

### 0/1 indicators for the two Nation levels, named to match the terms in
### the inclusion table
dat[["NationScotland"]] <- as.numeric(dat[["Nation"]] == "Scotland")
dat[["NationWales"]] <- as.numeric(dat[["Nation"]] == "Wales")

### Merge with summary statistics
### Variables to summarise, on their original (unsuffixed) names
vars <- c(
    "Lab_GE", "Lib_GE", "Nat_GE", "Oth_GE",
    "LabGovt", "LibGovt",
    "LabInc", "LibInc", "RPI", "winter",
    "Turn_at_GE",
    "LabPollChg", "LibPollChg", "OthPollChg",
    "ConCand_BE", "LabCand_BE", "LibCand_BE", "NatCand_BE", "OthCand_BE",
    "ConCand_GE", "LabCand_GE", "LibCand_GE", "NatCand_GE", "OthCand_GE",
    "IncCandidateLib", "IncCandidateOther",
    "NationScotland", "NationWales"
)

### One "mean (sd)" string per variable, in long form
smry_df <- dat %>%
    dplyr::select(all_of(vars)) %>%
    summarize_all(pretty_summary) %>%
    pivot_longer(cols = everything(),
                 names_to = "main_term",
                 values_to = "Summary")

### Keep every variable, whether or not it was selected, and sort by the
### order in which the main effect entered
out <- merge(smry_df, incl_df, by = "main_term", all = TRUE)
out <- arrange(out, `Main term`)

### Readable labels for the variables. The same map is applied to the row
### labels and to the column names, so it is defined once.
### Compared with the earlier duplicated lists this fixes the spelling of
### "Liberal candidate" and adds labels for LibInc, ConCand_GE and
### IncCandidateLib, which appear in the summary rows but previously kept
### their raw variable names.
var_labels <- c("Lib_GE" = "Liberal GE performance",
                "Lab_GE" = "Labour GE performance",
                "Nat_GE" = "Nationalist GE performance",
                "Oth_GE" = "Other GE performance",
                "LabGovt" = "Labour govt",
                "LibGovt" = "Lib Dems in govt",
                "LabInc" = "Labour party incumbency",
                "LibInc" = "Liberal party incumbency",
                "IncCandidateOther" = "Personal incumbent running as other",
                "IncCandidateLib" = "Personal incumbent running as Lib",
                "Turn_at_GE" = "Turnout at GE",
                "LibPollChg" = "Lib poll change",
                "LabPollChg" = "Labour poll change",
                "OthPollChg" = "Other poll change",
                "LibCand_BE" = "Liberal candidate",
                "LabCand_GE" = "Labour candidate in GE",
                "OthCand_GE" = "Other candidate in GE",
                "LibCand_GE" = "Liberal candidate in GE",
                "ConCand_GE" = "Conservative candidate in GE",
                "NatCand_BE" = "National candidate",
                "ConCand_BE" = "Conservative candidate",
                "OthCand_BE" = "Other candidate",
                "NationScotland" = "Scottish seat",
                "NationWales" = "Welsh seat",
                "NatCand_GE" = "Nationalist GE candidate",
                "RPI" = "RPI",
                "winter" = "Winter election")

### Row labels; values without an entry in the map are left unchanged
out <- out %>%
    mutate(main_term = dplyr::recode(main_term, !!!var_labels))

### Column names: the interaction columns carry variable names too, and the
### first column becomes "Variable"
names(out) <- dplyr::recode(names(out),
                            !!!var_labels,
                            "main_term" = "Variable")

### Show missing cells as blanks in the kable output
opts <- options(knitr.kable.NA = "")

### Write the table out as markdown
capture.output(kable(out,
                     caption = "Summary statistics and order of inclusion"),
               file = here::here("article/inserts",
                                 "summary_stats_1se.md"))
